Fast TreeSHAP 논문 리뷰
| Venue | Link |
|---|---|
| NeurIPS 2021 XAI4Debugging Workshop | arXiv |
1 Introduction
Yang [5]의 Fast TreeSHAP은 TreeSHAP의 계산 속도를 더 빠르게 만드는 논문이다. TreeSHAP은 tree-based model에서 SHAP value를 exact하게 계산할 수 있게 만든 중요한 알고리즘이다 [1]. 하지만 논문은 TreeSHAP이 이미 polynomial time으로 개선되었음에도, 실제 산업 규모에서는 여전히 병목이 될 수 있다고 지적한다.
여기서 말하는 산업 규모는 단순히 feature가 조금 많은 정도가 아니다. 수천 개 tree를 가진 model, depth가 큰 tree, 그리고 수천만 개 sample을 설명해야 하는 상황을 말한다. 논문은 예시로 maximum depth 12인 400-tree random forest model에서 2천만 개 sample을 설명하면 100-core server에서도 15시간 정도 걸릴 수 있다고 말한다.
이 병목은 두 가지 문제를 만든다.
- model diagnosis가 늦어진다.
- end user에게 feature reasoning을 제공하는 시간이 길어진다.
예를 들어 subscription propensity model을 생각해보자. 모델이 어떤 고객에게 구독 상품을 추천해야 한다고 판단했다면, 마케팅 팀은 “왜 이 고객이 대상인가?”를 알고 싶어 한다. SHAP value 계산이 너무 느리면, model scoring 이후 사람이 실제 action을 취하는 과정도 늦어진다.
이 논문의 목표는 TreeSHAP을 근본적으로 다른 explanation 방법으로 바꾸는 것이 아니다. TreeSHAP이 계산하는 exact SHAP value는 그대로 유지하면서, 계산 방식을 더 빠르게 만드는 것이다. 논문은 두 알고리즘을 제안한다.
- Fast TreeSHAP v1: memory cost는 유지하면서 평균 실행 시간을 줄인다.
- Fast TreeSHAP v2: tree에만 의존하는 값을 미리 계산해서, sample이 많을 때 더 크게 속도를 줄인다.
논문 abstract 기준으로 Fast TreeSHAP v1은 대략 1.5배, Fast TreeSHAP v2는 대략 2.5배 빠르다. Multi-time model interpretation scenario에서는 Fast TreeSHAP v2가 최대 3배 정도 빠른 결과도 보인다.
3 Background
3.1 SHAP Values
모델을 f라고 하자. 입력 feature 집합은 N이고, 특정 sample x에 대해 prediction f(x)를 설명하고 싶다. SHAP은 다음 additive surrogate model의 coefficient로 feature attribution을 표현한다.
g(z') = \phi_0 + \sum_{i \in N} \phi_i z'_i
여기서 z'_i=1이면 feature i가 observed라는 뜻이고, z'_i=0이면 unknown이라는 뜻이다. \phi_i가 feature i의 SHAP value이다.
SHAP value는 다음과 같이 쓴다.
\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!(|N|-|S|-1)!}{|N|!} \left( f_{S \cup \{i\}}(x) - f_S(x) \right)
여기서 f_S(x)는 feature subset S만 알려졌을 때의 model output이다. 문제는 이 식이 모든 subset S를 보아야 한다는 점이다. Feature 수가 많아지면 계산량이 빠르게 커진다.
TreeSHAP은 tree model에서 f_S(x)를 효율적으로 계산하고, 모든 subset을 직접 돌지 않도록 만든다. 핵심은 tree path의 구조를 이용하는 것이다.
3.2 SHAP Values for Trees
Tree model에서는 prediction이 root에서 leaf까지 내려가는 path로 결정된다. Feature가 subset S에 들어 있으면 해당 feature의 split을 실제 sample x에 따라 따라간다. 반대로 split feature가 S에 없으면, 양쪽 branch를 training data cover ratio로 weighted average한다.
논문은 tree를 다음 vector들로 표현한다.
- v: leaf node value
- a,b: left/right child index
- t: split threshold
- d: split feature index
- r: node cover, 즉 해당 node에 들어온 training sample 수
이때 f_S(x)를 계산하는 naive recursive algorithm은 tree node를 따라 내려가며 split feature가 S에 있는지 확인한다. 하지만 SHAP value는 모든 S에 대해 이 계산을 반복해야 하므로 exponential complexity가 된다.
3.3 TreeSHAP
TreeSHAP은 전체 subset을 직접 열거하지 않고, tree path 안에서 가능한 subset들의 비율을 동시에 추적한다. 논문은 TreeSHAP의 복잡도를 다음처럼 정리한다.
O(MTLD^2)
여기서 기호는 다음과 같다.
- M: 설명할 sample 수
- T: tree 수
- L: tree 하나의 최대 leaf 수
- D: tree의 최대 depth
Memory complexity는 다음과 같다.
O(D^2 + |N|)
TreeSHAP은 polynomial time이지만, M, T, L, D가 커지면 여전히 무겁다. 특히 산업에서는 M이 매우 클 수 있다. Fast TreeSHAP 논문은 TreeSHAP을 다시 들여다보고, 실제로 비싼 부분이 어디인지 찾아 줄인다.
4 Fast TreeSHAP
Fast TreeSHAP의 핵심은 TreeSHAP의 계산을 leaf path 단위로 다시 정리하는 것이다. 논문은 path P_k와 해당 leaf value v_k를 사용한다. 또 path에 등장하는 feature set을 D_k라고 둔다.
TreeSHAP은 원래 path에 등장하는 feature subset들을 다루지만, Fast TreeSHAP은 sample x가 path threshold를 만족하는 feature와 만족하지 않는 feature를 나눈다. Path P_k에서 threshold를 만족하지 않는 feature들의 집합을 S_k라고 하자.
S_k = \{d_j: j \in P_k,\ x_{d_j} \notin T_{kj}\}
즉 S_k는 sample x가 해당 leaf path를 실제로 따라가지 못하게 만드는 feature들이다. 이 구분이 Fast TreeSHAP v1의 핵심이다.
4.1 Fast TreeSHAP v1 Algorithm
Fast TreeSHAP v1은 TreeSHAP의 EXTEND/UNWIND 구조를 유지한다. 하지만 모든 path feature를 같은 방식으로 계산하지 않는다. Threshold를 만족하지 않는 feature가 들어간 subset은 leaf로 flow down할 수 없으므로, 굳이 똑같이 계산할 필요가 없다.
원래 TreeSHAP은 path depth |D_k|까지 subset size를 모두 고려한다. Fast TreeSHAP v1은 threshold를 만족하는 feature들만 대상으로 subset size를 추적한다. 따라서 실제로 loop가 도는 범위가 줄어든다.
논문은 Fast TreeSHAP v1의 theoretical complexity가 여전히 다음과 같다고 말한다.
O(MTLD^2)
즉 big-O만 보면 원래 TreeSHAP과 같다. 하지만 평균적으로 loop 길이가 줄고, leaf에서 UNWIND를 호출하는 횟수도 줄어든다. 논문은 평균 running time이 원래 TreeSHAP의 약 25%까지 줄 수 있다고 분석한다. 실험에서는 overhead와 다른 계산 때문에 실제 speedup은 대략 1.5배 정도로 나온다.
이 부분은 중요하다. Fast TreeSHAP v1은 asymptotic complexity를 바꾸는 알고리즘은 아니다. 대신 같은 complexity 안에서 불필요한 계산 범위를 줄이는 최적화이다. 그래서 memory cost도 원래 TreeSHAP과 같다.
O(D^2 + |N|)
실무적으로는 이 장점이 크다. Tree depth가 크거나 sample 수가 많아도 memory behavior가 원래 TreeSHAP과 비슷하기 때문에, 기존 TreeSHAP을 비교적 안전하게 대체할 수 있다.
4.2 Fast TreeSHAP v2 Algorithm
Fast TreeSHAP v2는 v1보다 더 공격적인 trade-off를 한다. 핵심은 pre-computation이다.
논문은 v1에서 leaf node의 UNWIND loop가 여전히 비싸다고 본다. Leaf에 도달하면 path에 있는 feature마다 contribution을 update해야 하고, 이 과정에서 U_{D_k,C} 같은 값이 반복 계산된다. 그런데 U_{D_k,C}는 sample 자체가 아니라 tree path와 subset C에만 의존한다.
따라서 v2는 다음 생각을 사용한다.
Sample마다 반복해서 계산하지 말고, tree가 고정되어 있을 때 가능한 U_{D_k,C}를 미리 계산해두자.
Fast TreeSHAP v2는 두 단계로 나뉜다.
Fast TreeSHAP Prep. 각 leaf path P_k에 대해 가능한 모든 subset C \subseteq D_k의 값을 미리 계산한다. 이 결과를 matrix S에 저장한다.
Fast TreeSHAP Score. 새 sample x가 들어오면, path에서 어떤 feature가 threshold를 만족하는지만 확인하고, 필요한 값을 precomputed matrix S에서 lookup한다.
이 구조는 time-memory trade-off이다. 계산 시간을 줄이는 대신 memory가 늘어난다.
Fast TreeSHAP v2의 time complexity는 general case에서 다음과 같다.
O(TL2^D D + MTLD)
Balanced tree에서는 다음처럼 정리된다.
O(TL2^D + MTLD)
Memory complexity는 다음과 같다.
O(L2^D)
원래 TreeSHAP의 O(MTLD^2)와 비교하면, scoring 단계에서 D factor를 하나 줄인 셈이다. 하지만 prep 단계와 memory가 2^D에 의존한다. 따라서 v2는 depth가 너무 큰 tree에는 부담이 될 수 있다.
논문은 v2가 특히 multi-time usage에 잘 맞는다고 말한다. 예를 들어 model은 한 달에 한 번 학습되지만, 매일 새로운 scoring data가 들어온다고 하자. 이 경우 tree는 그대로이므로 prep result를 한 번 계산해 저장해둘 수 있다. 그 다음 매일 들어오는 sample은 score 단계만 돌면 된다. 이때 v2의 장점이 커진다.
4.3 Fast TreeSHAP Summary
논문은 세 방법의 complexity를 다음처럼 비교한다.
| Method | Time complexity | Space complexity |
|---|---|---|
| TreeSHAP | O(MTLD^2) | O(D^2 + |N|) |
| Fast TreeSHAP v1 | O(MTLD^2) | O(D^2 + |N|) |
| Fast TreeSHAP v2 | O(TL2^D D + MTLD) | O(L2^D) |
정리하면 선택 기준은 다음과 같다.
- v1은 거의 항상 원래 TreeSHAP보다 낫다.
- v2는 sample 수 M이 충분히 크고, tree depth D가 memory 안에서 감당될 때 좋다.
- model은 고정되어 있고 scoring data가 반복적으로 들어오는 경우 v2가 특히 유리하다.
논문은 v2가 유리해지는 rough condition으로 다음을 제시한다.
M > \frac{2^{D+1}}{D}
예를 들어 D=10이면 M>205 정도이고, D=14이면 M>2341 정도이다. 즉 sample 수가 조금만 커져도 v2의 prep cost를 회수할 수 있다. 다만 memory는 여전히 확인해야 한다.
5 Evaluation
논문은 Adult, Superconductor, Crop, LinkedIn 내부 Upsell dataset에서 random forest model을 학습해 평가한다. 각 dataset마다 tree 수는 100개로 고정하고, maximum depth를 4, 8, 12, 16으로 바꾸어 small, medium, large, extra-large model을 만든다.
비교 대상은 다음 세 가지이다.
- Original TreeSHAP
- Fast TreeSHAP v1
- Fast TreeSHAP v2
구현은 SHAP package의 C 파일 treeshap.h를 직접 수정해서 공정하게 비교한다. 각 평가는 10,000 samples, single core 환경에서 수행한다.
논문 Table 3의 핵심 결과는 다음과 같다.
Fast TreeSHAP v1. Medium, large, extra-large model에서 대략 1.4-1.6배 speedup을 보인다. Small model에서는 계산 자체가 작아서 speedup이 1.2배 정도로 낮다.
Fast TreeSHAP v2. Medium, large model에서 대략 2.5-3배 speedup을 보인다. Extra-large model에서는 prep cost가 커져서 speedup이 2배 근처로 낮아진다.
논문은 correctness도 확인한다. Fast TreeSHAP v1/v2가 계산한 SHAP value와 original TreeSHAP의 최대 element-wise difference는 약 10^{-13} 수준이라고 한다. 이는 numerical error로 볼 수 있다. 즉 이 논문의 목표는 approximation이 아니라 exact computation의 acceleration이다.
Table 4는 v2를 Prep과 Score로 나누어 보여준다. Small/medium model에서는 prep cost가 거의 무시 가능하다. 하지만 depth가 커지면 prep cost와 matrix S의 memory가 빠르게 증가한다. 예를 들어 Super-xLarge에서는 S의 space allocation이 1.76GB까지 간다.
그래도 논문은 대부분의 industry tree model이 depth 16을 넘지 않으므로, ordinary laptop에서도 많은 경우 v2가 가능하다고 본다. 그리고 multi-time usage에서는 prep을 한 번만 하면 되므로, score 단계의 speedup이 더 직접적으로 나타난다.
6 Conclusion
Fast TreeSHAP 논문은 TreeSHAP의 theoretical guarantee를 유지하면서 계산 병목을 줄이는 방법을 제시한다. 핵심 기여는 두 가지이다.
첫째, Fast TreeSHAP v1은 path에서 threshold를 만족하는 feature에 계산을 집중한다. Big-O complexity는 그대로지만, 실제 loop 범위를 줄여 평균 실행 시간을 개선한다. Memory cost는 원래 TreeSHAP과 같다.
둘째, Fast TreeSHAP v2는 tree에만 의존하는 값을 pre-computation으로 분리한다. 이 덕분에 scoring 단계의 complexity를 O(MTLD)로 줄인다. 대신 O(L2^D) memory가 필요하다.
따라서 이 논문은 다음 trade-off를 명확히 보여준다.
| Version | 장점 | 비용 |
|---|---|---|
| Fast TreeSHAP v1 | 안전한 drop-in speedup | asymptotic complexity는 그대로 |
| Fast TreeSHAP v2 | sample이 많을 때 더 큰 speedup | prep cost와 memory 증가 |
TreeSHAP-IQ나 SII 글이 “무엇을 설명할 것인가”에 더 가깝다면, Fast TreeSHAP은 “이미 정해진 SHAP value를 어떻게 더 빠르게 계산할 것인가”에 가깝다. 그래서 이 논문은 새로운 attribution 개념을 제안한다기보다, SHAP 기반 explanation을 대규모 환경에서 실제로 돌릴 수 있게 만드는 engineering-oriented algorithm paper로 읽는 것이 좋다.
이 논문에서 중요한 점은 sample 수 M을 complexity 분석에 전면적으로 넣었다는 점이다. 모델 하나를 한 번 설명하는 연구용 setting에서는 TreeSHAP만으로 충분할 수 있다. 하지만 매일 수백만 개 sample을 scoring하고 explanation까지 붙여야 하는 production setting에서는 M이 가장 큰 병목이 된다. Fast TreeSHAP은 바로 그 지점을 겨냥한다.